import numpy as np

from skimage.filters import median, threshold_multiotsu
from scipy.ndimage import binary_erosion, binary_dilation

def bkgnd_multiotsu_threshold(img, otsu_nclass, otsu_thresh_class, img_trim_int=5, img_max_int=255):
    """
    Get background threshold of an image, using multi-class Otsu's method.

    Before thresholding, pixels whose intensity is not strictly inside
    ``(img_trim_int, img_max_int - img_trim_int)`` are discarded (presumably to
    keep near-black and near-saturated pixels from biasing the thresholds).
    The multi-Otsu thresholds of the remaining pixels are printed, and one of
    them is returned.

    Parameters
    ----------
    img: ndarray
        2D array of grayscale image
    otsu_nclass: int
        Number of classes for multi-class Otsu thresholding; yields
        ``otsu_nclass - 1`` thresholds
    otsu_thresh_class: int
        Index into the array of Otsu thresholds selecting the class boundary
        used as background threshold
    img_trim_int: int
        Grayscale intensity margin used to trim both ends of the intensity
        range before calculating Otsu thresholds
    img_max_int: int, optional
        Maximum representable intensity of ``img`` (default 255, i.e. 8-bit).
        Set e.g. to 4095 or 65535 for 12-/16-bit images.

    Returns
    -------
    Background threshold intensity (element ``otsu_thresh_class`` of the
    multi-Otsu threshold array)

    Raises
    ------
    ValueError
        If no pixels remain after trimming.
    IndexError
        If ``otsu_thresh_class`` is out of range for the threshold array.
    """
    lower = img_trim_int
    upper = img_max_int - img_trim_int
    trimmed_img = img[(img > lower) & (img < upper)]

    # Fail with a clear message instead of a cryptic error from inside
    # threshold_multiotsu when trimming removed every pixel.
    if trimmed_img.size == 0:
        raise ValueError(
            f"no pixels left after trimming to the open interval ({lower}, {upper}); "
            "check 'img_trim_int' / 'img_max_int'"
        )

    otsu_thresh = threshold_multiotsu(trimmed_img, classes=otsu_nclass)
    print("multi-otsu thresholds are: " + str(otsu_thresh))

    return otsu_thresh[otsu_thresh_class]


def threshold_mask_center_vs_outline(img, masks, k, size_thld, mask_threshold_dict):
    """
    Threshold algorithm, comparing the image grayscale intensity of a mask's center vs. its outline,
    either as difference or ratio, against a pre-specified threshold. Also compares against a
    pre-specified size threshold.

    A mask is accepted only if it has at least ``size_thld`` pixels AND its
    center/outline contrast (``center_avg - outline_avg`` for 'diff',
    ``center_avg / outline_avg`` for 'ratio') is strictly greater than
    ``intst_thld``. If either the center or the outline region contains no
    pixels, the contrast cannot be measured and the mask is rejected.

    Regions (built with 4-connected binary erosion/dilation):

    - edge: the innermost 1-pixel ring of the mask (mask minus its 1-step erosion).
    - outline: the edge, plus (when ``mask_outline_pix >= 2``) a ring of
      ``mask_outline_pix - 1`` dilation steps outside the mask, excluding any
      pixel whose label in ``masks`` is non-zero.
    - center: with ``mask_center_pix == 0`` (or < 1), the whole mask minus the
      edge; with ``mask_center_pix >= 1``, the band removed between the 1st and
      the ``(1 + mask_center_pix)``-th erosion, i.e. a band of that thickness
      just inside the edge.

    Parameters
    ----------
    img: ndarray of int
        2D image array in grayscale (intensities assumed non-negative)
    masks: ndarray of int
        2D array of mask labels; 0 is treated as background
    k: int
        Label of the mask being thresholded
    size_thld: float
        Minimum number of pixels for the mask to be accepted
    mask_threshold_dict: dict
        Contains: mask_outline_pix (pixel thickness of mask outline for thresholding, numbers < 2 use
        only the 1-pixel edge, outer pixels that overlap with other candidate masks excluded);
        mask_center_pix (pixel thickness of mask center band for thresholding, set to 0 to use whole
        mask minus its edge as center); intst_thld (intensity threshold for comparing mask center vs.
        mask outline); intst_thld_type ('diff' or 'ratio').

    Returns
    -------
    mask_accept: bool
        Whether the mask passes both the size and the intensity threshold.
    mask_avg_int: float
        Mean intensity of ``img`` over all pixels of mask ``k``.
    mask_npix: int
        Number of pixels in mask ``k``.

    Raises
    ------
    ValueError
        If ``intst_thld_type`` is neither 'diff' nor 'ratio'.
    """
    mask_outline_pix = mask_threshold_dict['mask_outline_pix']
    mask_center_pix = mask_threshold_dict['mask_center_pix']
    intst_thld = mask_threshold_dict['intst_thld']
    intst_thld_type = mask_threshold_dict['intst_thld_type']

    if intst_thld_type not in ('diff', 'ratio'):
        raise ValueError("'intst_thld_type' needs to be either 'diff' or 'ratio'")

    # Neither input is modified below, so no defensive copies are needed.
    mask = masks == k
    mask_npix = np.count_nonzero(mask)
    mask_avg_int = np.average(img[mask])

    if mask_npix < size_thld:
        return False, mask_avg_int, mask_npix

    # bitwise XOR to subtract: mask minus its erosion gives the 1-pixel edge
    mask_eroded = binary_erosion(mask, iterations=1)
    mask_edge = mask ^ mask_eroded

    mask_outline = mask_edge
    if mask_outline_pix >= 2:
        # ring outside the mask, (mask_outline_pix - 1) dilation steps thick
        outer_ring = mask ^ binary_dilation(mask, iterations=mask_outline_pix - 1)
        outer_ring[masks != 0] = False  # remove outline pixels that overlap with other masks
        mask_outline = mask_edge | outer_ring

    if mask_center_pix >= 1:
        mask_center = mask_eroded ^ binary_erosion(mask, iterations=1 + mask_center_pix)
    else:
        mask_center = mask_eroded  # == mask ^ mask_edge

    # Without pixels on both sides there is no contrast to measure; the old
    # sentinel values (1000 / 0) could make such masks pass the threshold.
    if not mask_outline.any() or not mask_center.any():
        return False, mask_avg_int, mask_npix

    mask_ol_avg = np.average(img[mask_outline])
    mask_cen_avg = np.average(img[mask_center])

    if intst_thld_type == 'diff':
        mask_score = mask_cen_avg - mask_ol_avg
    elif mask_ol_avg > 0:
        mask_score = mask_cen_avg / mask_ol_avg
    elif mask_cen_avg > 0:
        mask_score = np.inf  # bright center over a zero-intensity outline
    else:
        mask_score = 1.0  # center and outline both zero: no contrast

    mask_accept = bool(mask_score > intst_thld)

    return mask_accept, mask_avg_int, mask_npix

